Estimating a piecewise-exponential (PEM) survival model with a Gamma baseline hazard

%load_ext autoreload
import random
import survivalstan
import numpy as np
import pandas as pd
from stancache import stancache
from matplotlib import pyplot as plt
/*  Piecewise-exponential (PEM) survival model with a gamma prior on the
    per-timepoint baseline hazard, fit as a Poisson regression on
    long-format data (one row per sample id / timepoint id combination).

 Variable naming:
 // dimensions
 N          = total number of observations (length of data)
 S          = number of sample ids
 T          = max timepoint (number of timepoint ids)
 M          = number of covariates

 // data
 s          = sample id for each obs
 t          = timepoint id for each obs
 event      = integer indicating if there was an event at time t for sample s
 x          = matrix of real-valued covariates at time t for sample n [N, M]
 obs_t      = observed end time for interval for timepoint for that obs
              (declared for interface compatibility; not used below)
 t_dur      = per-timepoint value scaling the baseline-hazard prior
              (presumably the interval duration -- confirm with caller)
 t_obs      = per-timepoint time reported as the predicted failure /
              censoring time in generated quantities
*/
// Jacqueline Buros Novik

data {
  int<lower=1> N;
  int<lower=1> S;
  int<lower=1> T;
  int<lower=0> M;
  int<lower=1, upper=S> s[N];     // sample id; indexes n_trans[S, T], so must be <= S
  int<lower=1, upper=T> t[N];     // timepoint id
  int<lower=0, upper=1> event[N]; // 1: event, 0:censor
  matrix[N, M] x;                 // explanatory vars
  real<lower=0> obs_t[N];         // observed end time for each obs (unused)
  real<lower=0> t_dur[T];         // scales the gamma prior on baseline[tp]
  real<lower=0> t_obs[T];         // time reported for timepoint tp in predictions
}
transformed data {
  real c_unit;
  real r_unit;
  int n_trans[S, T];

  // scale for baseline hazard params (fixed)
  c_unit = 0.001;
  r_unit = 0.1;

  // n_trans used to map each sample*timepoint to n (used in gen quantities)
  // map each patient/timepoint combination to n values
  for (n in 1:N) {
      n_trans[s[n], t[n]] = n;
  }

  // fill in missing values with n for max t for that patient
  // ie assume "last observed" state applies forward (may be problematic for TVC)
  // this allows us to predict failure times >= observed survival times
  for (samp in 1:S) {
      int last_value;
      last_value = 0;  // 0 = no observed row seen yet for this sample
      for (tp in 1:T) {
          // manual says ints are initialized to neg values
          // so <=0 is a shorthand for "unassigned"
          if (n_trans[samp, tp] > 0) {
              last_value = n_trans[samp, tp];
          } else if (last_value != 0) {
              n_trans[samp, tp] = last_value;
          }
          // NOTE(review): timepoints before a sample's first observed row stay
          // unassigned; the generated-quantities lookup assumes each sample
          // is observed from timepoint 1 onward.
      }
  }
}
parameters {
  vector<lower=0>[T] baseline; // unstructured baseline hazard for each timepoint t
  vector[M] beta; // beta for each covariate
  real<lower=0> c_raw;
  real<lower=0> r_raw;
}
transformed parameters {
  vector[N] log_hazard;
  vector[T] log_baseline;
  real<lower=0> c;
  real<lower=0> r;

  log_baseline = log(baseline);

  // rescale the unit-scale raw params by the fixed scales above
  r = r_unit*r_raw;
  c = c_unit*c_raw;

  for (n in 1:N) {
    log_hazard[n] = x[n,]*beta + log_baseline[t[n]];
  }
}
model {
  // gamma(shape, rate): prior mean of baseline[i] is r * t_dur[i] and its
  // variance is r * t_dur[i] / c, so the small c_unit gives a diffuse prior
  for (i in 1:T) {
      baseline[i] ~ gamma(r * t_dur[i] * c, c);
  }
  beta ~ cauchy(0, 2);
  event ~ poisson_log(log_hazard);
  // half-normal, given the <lower=0> constraints
  c_raw ~ normal(0, 1);
  r_raw ~ normal(0, 1);
}
generated quantities {
  real log_lik[N];
  int y_hat_mat[S, T];     // ppcheck for each S*T combination
  real y_hat_time[S];      // predicted failure time for each sample
  int y_hat_event[S];      // predicted event (0:censor, 1:event)

  // log-likelihood, for loo
  for (n in 1:N) {
      log_lik[n] = poisson_log_lpmf(event[n] | log_hazard[n]);
  }

  // posterior predicted values: walk each sample forward through the
  // timepoints until the first simulated event
  for (samp in 1:S) {
      int sample_alive;
      sample_alive = 1;
      for (tp in 1:T) {
          if (sample_alive == 1) {
              real log_haz;
              int n;
              int pred_y;

              // determine predicted value of y, using the covariates of the
              // (forward-filled) row for this sample/timepoint
              n = n_trans[samp, tp];
              log_haz = x[n,]*beta + log_baseline[tp];
              // poisson_log_rng requires a log rate below 30 * log(2);
              // beyond that, record the sentinel 9 instead of drawing
              if (log_haz < log(pow(2, 30)))
                  pred_y = poisson_log_rng(log_haz);
              else
                  pred_y = 9;

              // mark this patient as ineligible for future tps
              // note: deliberately make 9s ineligible
              if (pred_y >= 1) {
                  sample_alive = 0;
                  y_hat_time[samp] = t_obs[tp];
                  y_hat_event[samp] = 1;
              }

              // save predicted value of y to matrix
              y_hat_mat[samp, tp] = pred_y;
          } else {
              // after the predicted event, fill remaining tps with sentinel 9
              y_hat_mat[samp, tp] = 9;
          }
      } // end per-timepoint loop

      // if patient still alive at max, report censoring at the last timepoint
      if (sample_alive == 1) {
          y_hat_time[samp] = t_obs[T];
          y_hat_event[samp] = 0;
      }
  } // end per-sample loop
}

# Simulate 100 subjects with exponential failure times whose log-rate is
# -3 + 0.5 * [sex == male] (rates exp(-3) and exp(-2.5)), censored at t = 20.
# stancache stores the result on disk so re-running the notebook is cheap.
d = stancache.cached(
    survivalstan.sim.sim_data_exp_correlated,
    N=100,
    censor_time=20,
    rate_form='1 + sex',
    rate_coefs=[-3, 0.5],
)
# Center age so the intercept/baseline refers to a subject of mean age.
d['age_centered'] = d['age'] - d['age'].mean()
d.head()
INFO:stancache.stancache:sim_data_exp_correlated: cache_filename set to sim_data_exp_correlated.cached.N_100.censor_time_20.rate_coefs_54462717316.rate_form_1 + sex.pkl
INFO:stancache.stancache:sim_data_exp_correlated: Loading result from cache
age sex rate true_t t event index age_centered
0 59 male 0.082085 20.948771 20.000000 False 0 4.18
1 58 male 0.082085 12.827519 12.827519 True 1 3.18
2 61 female 0.049787 27.018886 20.000000 False 2 6.18
3 57 female 0.049787 62.220296 20.000000 False 3 2.18
4 55 male 0.082085 10.462045 10.462045 True 4 0.18
# Observed survival curves, one per sex, on the same axes.
for sex in ('female', 'male'):
    survivalstan.utils.plot_observed_survival(
        df=d[d['sex'] == sex], event_col='event', time_col='t', label=sex)
# Without an explicit legend the per-sex `label`s above are never shown.
plt.legend()
<matplotlib.legend.Legend at 0x7f9fac008eb8>
# Expand each subject into "long" format: one row per subject and interval
# end time, adding the `end_time` / `end_failure` columns the PEM fit uses.
# Cached on disk by stancache.
dlong = stancache.cached(
    survivalstan.prep_data_long_surv,
    df=d, event_col='event', time_col='t',
)
dlong.head()
INFO:stancache.stancache:prep_data_long_surv: cache_filename set to prep_data_long_surv.cached.df_33772694934.event_col_event.time_col_t.pkl
INFO:stancache.stancache:prep_data_long_surv: Loading result from cache
age sex rate true_t t event index age_centered key end_time end_failure
0 59 male 0.082085 20.948771 20.0 False 0 4.18 1 20.000000 False
1 59 male 0.082085 20.948771 20.0 False 0 4.18 1 12.827519 False
2 59 male 0.082085 20.948771 20.0 False 0 4.18 1 10.462045 False
3 59 male 0.082085 20.948771 20.0 False 0 4.18 1 0.196923 False
4 59 male 0.082085 20.948771 20.0 False 0 4.18 1 9.244121 False
# Fit the gamma-baseline PEM model to the long-format data. Passing
# stancache.cached_stan_fit as FIT_FUN caches both the compiled Stan model
# and the posterior draws on disk.
testfit = survivalstan.fit_stan_survival_model(
    model_cohort='test model',
    model_code=survivalstan.models.pem_survival_model_gamma,
    df=dlong,
    sample_col='index',
    timepoint_end_col='end_time',
    event_col='end_failure',
    formula='~ age_centered + sex',
    iter=5000,
    chains=4,
    seed=9001,
    FIT_FUN=stancache.cached_stan_fit,
)
INFO:stancache.stancache:Step 1: Get compiled model code, possibly from cache
INFO:stancache.stancache:StanModel: cache_filename set to anon_model.cython_0_25_1.model_code_72990130769.pystan_2_12_0_0.stanmodel.pkl
INFO:stancache.stancache:StanModel: Loading result from cache
INFO:stancache.stancache:Step 2: Get posterior draws from model, possibly from cache
INFO:stancache.stancache:sampling: cache_filename set to anon_model.cython_0_25_1.model_code_72990130769.pystan_2_12_0_0.stanfit.chains_4.data_64545635565.iter_5000.seed_9001.pkl
INFO:stancache.stancache:sampling: Starting execution
INFO:stancache.stancache:sampling: Execution completed (0:02:26.153391 elapsed)
INFO:stancache.stancache:sampling: Saving results to cache
/home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/stancache/ UserWarning: Pickling fit objects is an experimental feature!
The relevant StanModel instance must be pickled along with this fit object.
When unpickling the StanModel must be unpickled first.
  pickle.dump(res, open(cache_filepath, 'wb'), pickle.HIGHEST_PROTOCOL)
/home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/stanity/ FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison
  elif sort == 'in-place':
/home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/stanity/ VisibleDeprecationWarning: using a non-integer number instead of an integer will result in an error in the future
  bs /= 3 * x[sort[np.floor(n/4 + 0.5) - 1]]
/home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/stanity/ RuntimeWarning: overflow encountered in exp
  np.exp(temp, out=temp)
/home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/stanity/ RuntimeWarning: divide by zero encountered in true_divide
  bs /= 3 * x[sort[np.floor(n/4 + 0.5) - 1]]
/home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/stanity/ RuntimeWarning: invalid value encountered in multiply
  temp = ks[:,None] * x
/home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/stanity/ RuntimeWarning: invalid value encountered in greater_equal
  dii = w >= 10 * np.finfo(float).eps
/home/jacquelineburos/miniconda3/envs/python3/lib/python3.5/site-packages/stanity/ RuntimeWarning: invalid value encountered in double_scalars
  sigma = -k / b
# Posterior summary of the log density (lp__); an Rhat close to 1 is a basic
# check that the sampled chains agree with each other.
survivalstan.utils.print_stan_summary([testfit], pars='lp__')
            mean  se_mean        sd         2.5%          50%       97.5%      Rhat
lp__ -1034.96497  0.18312  7.354494 -1050.563484 -1034.543604 -1021.53023  1.000819
In [9]:
# Posterior summary of the log baseline hazard for each timepoint.
survivalstan.utils.print_stan_summary([testfit], pars='log_baseline')
                        mean   se_mean          sd        2.5%         50%      97.5%      Rhat
log_baseline[0]    -5.540533  0.026349    1.317162   -8.741271   -5.317482  -3.605955  1.000660
log_baseline[1]    -5.540373  0.026410    1.312295   -8.765512   -5.305697  -3.661229  1.000160
log_baseline[2]    -5.533987  0.025182    1.277367   -8.563403   -5.341926  -3.613802  1.000899
log_baseline[3]    -5.499061  0.023197    1.266738   -8.526801   -5.300919  -3.598318  1.000996
log_baseline[4]    -5.529027  0.029142    1.376488   -8.883094   -5.280803  -3.553426  1.001002
log_baseline[5]    -5.471047  0.025521    1.260655   -8.429872   -5.272214  -3.596805  1.002028
log_baseline[6]    -5.499416  0.034998    1.321609   -8.548750   -5.291451  -3.584369  1.003210
log_baseline[7]    -5.464198  0.024734    1.288292   -8.466082   -5.228081  -3.575405  1.000619
log_baseline[8]    -5.466310  0.027385    1.284197   -8.583571   -5.262634  -3.536853  1.001117
log_baseline[9]    -5.416356  0.024357    1.279819   -8.469762   -5.221550  -3.528211  1.001357
log_baseline[10]   -5.440701  0.027773    1.306535   -8.676018   -5.228972  -3.541475  1.001608
log_baseline[11]   -5.434051  0.027698    1.305317   -8.664646   -5.225114  -3.512998  0.999949
log_baseline[12]   -5.370823  0.028533    1.284318   -8.459077   -5.161179  -3.454565  1.001136
log_baseline[13]   -5.393917  0.024404    1.280480   -8.626593   -5.192589  -3.486274  1.000423
log_baseline[14]   -5.364330  0.025443    1.273700   -8.327442   -5.166064  -3.490194  1.002779
log_baseline[15]   -5.357747  0.025662    1.282847   -8.510490   -5.145634  -3.463617  1.002548
log_baseline[16]   -5.378575  0.027974    1.302219   -8.521636   -5.152384  -3.489239  1.003528
log_baseline[17]   -5.295181  0.026170    1.249062   -8.246395   -5.115480  -3.442103  1.004022
log_baseline[18]   -5.346658  0.027128    1.301594   -8.462238   -5.129009  -3.403453  1.002558
log_baseline[19]   -5.293308  0.027616    1.280204   -8.270132   -5.103891  -3.405552  1.000711
log_baseline[20]   -5.312085  0.024626    1.318119   -8.478142   -5.086743  -3.394770  1.000745
log_baseline[21]   -5.277845  0.026557    1.264723   -8.303782   -5.068627  -3.349496  1.000715
log_baseline[22]   -5.306103  0.027747    1.309989   -8.480791   -5.098819  -3.372876  1.002474
log_baseline[23]   -5.240126  0.024418    1.274211   -8.374662   -5.031179  -3.368183  1.002631
log_baseline[24]   -5.180366  0.023456    1.229393   -8.199401   -4.991827  -3.339248  1.002037
log_baseline[25]   -5.220781  0.025189    1.279213   -8.155421   -5.026628  -3.347359  0.999862
log_baseline[26]   -5.198388  0.027133    1.293306   -8.334439   -4.988685  -3.310959  1.002848
log_baseline[27]   -5.145882  0.024486    1.269478   -8.201074   -4.960108  -3.261169  1.000358
log_baseline[28]   -5.112184  0.023308    1.235079   -8.159944   -4.913842  -3.245869  1.001038
log_baseline[29]   -5.130311  0.028653    1.303939   -8.381742   -4.914959  -3.271805  1.001287
log_baseline[30]   -5.197454  0.026676    1.342558   -8.362530   -4.956216  -3.213312  1.000961
log_baseline[31]   -5.188365  0.027795    1.344816   -8.474813   -4.969168  -3.277665  1.002005
log_baseline[32]   -5.102617  0.024780    1.260118   -8.178528   -4.901117  -3.213594  1.000780
log_baseline[33]   -5.075221  0.024387    1.227845   -7.991722   -4.885576  -3.214388  1.000436
log_baseline[34]   -5.110409  0.029528    1.336298   -8.447845   -4.872393  -3.193095  1.001111
log_baseline[35]   -5.031990  0.024814    1.248103   -8.051273   -4.841412  -3.165030  1.000698
log_baseline[36]   -5.050843  0.025298    1.288200   -8.189802   -4.835042  -3.140867  1.002413
log_baseline[37]   -5.027862  0.025181    1.274591   -7.993631   -4.829429  -3.130617  1.002734
log_baseline[38]   -5.043253  0.031843    1.332076   -8.392334   -4.800582  -3.117772  1.001792
log_baseline[39]   -5.009148  0.027024    1.286125   -8.114673   -4.801411  -3.083061  1.001948
log_baseline[40]   -4.977261  0.022727    1.253310   -8.021432   -4.786938  -3.099191  1.000618
log_baseline[41]   -4.973488  0.026404    1.297564   -8.165326   -4.762586  -3.063354  1.000620
log_baseline[42]   -4.976570  0.026160    1.289808   -8.134780   -4.785349  -3.061472  1.000685
log_baseline[43]   -4.919019  0.025599    1.277389   -8.050034   -4.700221  -3.016657  1.001715
log_baseline[44]   -4.900334  0.025927    1.293227   -8.049396   -4.679596  -2.980225  1.000656
log_baseline[45]   -4.907134  0.025745    1.276647   -8.046290   -4.690973  -3.021486  1.001499
log_baseline[46]   -4.855415  0.023345    1.255215   -7.826891   -4.654950  -2.972896  1.001151
log_baseline[47]   -4.832234  0.021958    1.254851   -7.809152   -4.639160  -2.926565  1.001058
log_baseline[48]   -4.798222  0.024395    1.255120   -7.870438   -4.580276  -2.924569  1.000580
log_baseline[49]   -4.838834  0.029046    1.346472   -8.135627   -4.612250  -2.866454  1.001520
log_baseline[50]   -4.760367  0.024182    1.265330   -7.845615   -4.552539  -2.875869  1.000354
log_baseline[51]   -4.737922  0.024380    1.249825   -7.791359   -4.524863  -2.860117  1.000065
log_baseline[52]   -4.751358  0.028862    1.327670   -8.044268   -4.540209  -2.862803  1.000918
log_baseline[53]   -4.685420  0.023722    1.284065   -7.852333   -4.469523  -2.795754  1.001032
log_baseline[54]   -4.695464  0.025316    1.294588   -7.810316   -4.503222  -2.782197  1.000505
log_baseline[55]   -4.662545  0.027774    1.317164   -7.848580   -4.440736  -2.771303  1.000641
log_baseline[56]   -4.669330  0.030262    1.321870   -7.914412   -4.434960  -2.743924  1.000618
log_baseline[57]   -4.620962  0.025575    1.282324   -7.585949   -4.400851  -2.737336  1.003473
log_baseline[58]   -4.564812  0.025214    1.287904   -7.671804   -4.338766  -2.704413  1.001306
log_baseline[59]   -4.562128  0.030414    1.350609   -7.910983   -4.324704  -2.648129  1.000886
log_baseline[60]   -4.542367  0.032122    1.338757   -7.768437   -4.320009  -2.624694  1.002409
log_baseline[61]   -4.480764  0.025349    1.287803   -7.550987   -4.280852  -2.593134  1.001754
log_baseline[62]   -4.430831  0.024307    1.228657   -7.382741   -4.239083  -2.589439  1.001206
log_baseline[63]   -4.395725  0.023673    1.273268   -7.384986   -4.209973  -2.479231  1.000442
log_baseline[64]   -4.346682  0.025621    1.266635   -7.426116   -4.141308  -2.442702  1.002046
log_baseline[65]   -4.312951  0.024879    1.298957   -7.481565   -4.094432  -2.411871  1.001782
log_baseline[66]   -4.272058  0.023535    1.271993   -7.300271   -4.066569  -2.414926  1.001380
log_baseline[67]   -4.278248  0.024941    1.301696   -7.406780   -4.063855  -2.381025  1.001246
log_baseline[68]   -4.239232  0.022731    1.255980   -7.282352   -4.042431  -2.363243  1.000641
log_baseline[69]   -4.220383  0.023864    1.257355   -7.227408   -4.033891  -2.300163  1.000607
log_baseline[70]   -4.174461  0.033438    1.387179   -7.389505   -3.925402  -2.265936  1.002038
log_baseline[71]   -4.090127  0.022568    1.252697   -7.085192   -3.892866  -2.207512  1.002348
log_baseline[72]   -4.110447  0.027878    1.315899   -7.399702   -3.860361  -2.205056  1.001732
log_baseline[73]   -4.056023  0.027877    1.311094   -7.247058   -3.837001  -2.137026  1.001115
log_baseline[74]   -4.023769  0.024615    1.255600   -7.067530   -3.820071  -2.142365  1.001948
log_baseline[75]   -4.049963  0.026498    1.295967   -7.201009   -3.854413  -2.140042  1.000684
log_baseline[76]   -4.047071  0.028314    1.357597   -7.297987   -3.803428  -2.082792  1.001065
log_baseline[77] -322.669893  5.131864  200.339993 -685.976286 -304.462026 -18.223910  1.002720
# Plot posterior summaries of the per-timepoint baseline hazard.
# NOTE(review): this call warned that a row had a NaN Rhat; the last
# timepoint's baseline looks poorly identified and is worth investigating.
survivalstan.utils.plot_stan_summary([testfit], pars='baseline')
INFO:survivalstan.utils:Warning - 1 rows removed due to NaN values for Rhat. This may indicate a problem in your model estimation.
# Plot posterior estimates of the baseline-hazard parameters (one per timepoint).
survivalstan.utils.plot_coefs([testfit], element='baseline')
# Posterior-predicted survival curves (unfilled), with the observed survival
# curve overlaid in green for comparison.
survivalstan.utils.plot_pp_survival([testfit], fill=False)
survivalstan.utils.plot_observed_survival(
    df=d, event_col='event', time_col='t', color='green', label='observed')
# Show the 'observed' label; without this call no legend is drawn.
plt.legend()
<matplotlib.legend.Legend at 0x7f9ec3bcba20>
# Posterior-predicted survival curves, stratified by sex.
survivalstan.utils.plot_pp_survival([testfit], by='sex')
# Overlay per-sex posterior-predicted survival on the observed per-sex curves.
ppsurv = survivalstan.utils.prep_pp_survival_data([testfit], by='sex')
subplot = plt.subplots(1, 1)  # (fig, ax) tuple shared by the PP layers below
sex_colors = {'male': 'blue', 'female': 'red'}

# Posterior-predictive layers first (male, then female)...
for sex in ('male', 'female'):
    survivalstan.utils._plot_pp_survival_data(
        ppsurv[ppsurv['sex'] == sex].copy(),
        subplot=subplot, color=sex_colors[sex], alpha=0.3)

# ...then the observed curves on top (female, then male) in matching colours.
for sex in ('female', 'male'):
    survivalstan.utils.plot_observed_survival(
        df=d[d['sex'] == sex], event_col='event', time_col='t',
        color=sex_colors[sex], label=sex)

# Show the per-sex labels; without this call no legend is drawn.
plt.legend()
<matplotlib.legend.Legend at 0x7f9f9714d7f0>
